

# ################################################
# Unified: reparam + EL
# ################################################

# Constrained joint maximum-likelihood fit of the M, L and Y models.
#
# Arguments:
#   dat         data frame with columns M, L, Y (M and L are treated as 0/1
#               indicators) plus every variable used in the formulas
#   fmla        list holding the formulas fmla_m, fmla_l and fmla_f
#   px          passed unchanged to estimate_Y() (defined elsewhere)
#   beta_start  start vector ordered (beta_m, beta_l, beta_f, w0, wa); when it
#               has length 0 every coefficient starts at 0.1
#   opt         list; opt$tau_l and opt$tau_u are the lower / upper bound on wa
#
# Returns list(beta_m, beta_l, beta_y) of named vectors, where
# beta_y = (beta_f, w0, wa).
optimize_unified_reparam <- function(dat, fmla, px, beta_start, opt){

  # Design matrices. na.action = NULL keeps rows with missing values in place.
  mf <- model.frame(dat, na.action = NULL)
  Xm <- as.matrix(model.matrix(fmla$fmla_m, data = mf))
  Xl <- as.matrix(model.matrix(fmla$fmla_l, data = mf))
  Xf <- as.matrix(model.matrix(fmla$fmla_f, data = mf))

  # Default start values: 0.1 everywhere, with 2 extra slots for w0 and wa
  if (length(beta_start) == 0){
    beta_start <- rep(0.1, ncol(Xm) + ncol(Xl) + ncol(Xf) + 2)
    names(beta_start) <- c(colnames(Xm), colnames(Xl), colnames(Xf), "w0", "wa")
  }

  # Objective: average negative log-likelihood of (M, L, Y).
  # M and L use logistic models; Y is normal around estimate_Y() with sd 1.
  eval_f <- function(beta, dat, Xm, Xl, Xf, px, opt){
    n <- nrow(dat)
    p <- length(beta)

    # Split the flat parameter vector into its blocks
    beta_m <- beta[1:ncol(Xm)]
    beta_l <- beta[(ncol(Xm)+1):(ncol(Xm)+ncol(Xl))]
    beta_f <- beta[(ncol(Xm)+ncol(Xl)+1):(p-2)]
    w0 <- beta[p-1]
    wa <- beta[p]

    # Names are restored on every call before the blocks go to estimate_Y()
    names(beta_m) <- colnames(Xm)
    names(beta_l) <- colnames(Xl)
    names(beta_f) <- colnames(Xf)
    names(w0) <- c("w0")
    names(wa) <- c("wa")
    beta_y <- c(beta_f, w0, wa)

    M <- dat$M
    L <- dat$L
    Y <- dat$Y

    Y_hat <- estimate_Y(dat, beta_y, beta_l, beta_m, px)
    p_Y <- dnorm(Y, Y_hat, 1)

    p_L1 <- 1/(1+exp(-Xl%*%beta_l))
    p_L <- L*p_L1 + (1-L)*(1-p_L1)

    p_M1 <- 1/(1+exp(-Xm%*%beta_m))
    p_M <- M*p_M1 + (1-M)*(1-p_M1)

    f <- sum(log(p_M) + log(p_L) + log(p_Y))

    return(-f/n)
  }

  # Inequality constraints, both required to be <= 0 by nloptr:
  #   wa - tau_u <= 0  and  tau_l - wa <= 0,  i.e.  tau_l <= wa <= tau_u.
  # The last parameter wa is used directly as the constrained quantity.
  eval_g_ineq <- function(beta, dat, Xm, Xl, Xf, px, opt){
    pse <- beta[length(beta)]
    return(c(pse - opt$tau_u, opt$tau_l - pse))
  }

  # Solve the optimization problem (derivative-free COBYLA)
  mle_res <- nloptr(x0=beta_start,
                    eval_f=eval_f,
                    eval_g_ineq=eval_g_ineq,
                    opts = list("algorithm"="NLOPT_LN_COBYLA","xtol_rel"=1.0e-3, "maxeval"=5000),
                    dat=dat, Xm=Xm, Xl=Xl, Xf=Xf, px=px, opt=opt)

  # Split the solution back into named coefficient blocks
  beta <- mle_res$solution
  p <- length(beta)

  beta_m <- beta[1:ncol(Xm)]
  beta_l <- beta[(ncol(Xm)+1):(ncol(Xm)+ncol(Xl))]
  beta_y <- beta[(ncol(Xm)+ncol(Xl)+1):p]
  names(beta_m) <- colnames(Xm)
  names(beta_l) <- colnames(Xl)
  names(beta_y) <- c(colnames(Xf), "w0", "wa")

  return(list(beta_m = beta_m,
              beta_l = beta_l,
              beta_y = beta_y))
}

# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

# Per-observation values m used by the caller to form sum(px * m).
#
# Arguments:
#   beta  list with named coefficient vectors beta_m, beta_l and beta_y,
#         where beta_y = (beta_f, w0, wa)
#   dat   data set handed to process_data() (defined elsewhere); dat$M and
#         dat$L are read directly
#   px    observation weights used in the weighted sums of f
#
# Returns m, one value per row.
get_m_unified <- function(beta, dat, px){

  coef_m <- beta$beta_m
  coef_l <- beta$beta_l
  coef_y <- beta$beta_y

  # beta_y is laid out as (beta_f, w0, wa)
  k <- length(coef_y)
  coef_f <- coef_y[1:(k-2)]
  w0 <- coef_y[k-1]
  wa <- coef_y[k]

  # Copies of the data with A, M and L fixed to constants
  d_a0m0l0 <- process_data(dat, a = 0, m = 0, l = 0)
  d_a0m0l1 <- process_data(dat, a = 0, m = 0, l = 1)
  d_a0m1l0 <- process_data(dat, a = 0, m = 1, l = 0)
  d_a0m1l1 <- process_data(dat, a = 0, m = 1, l = 1)
  d_a1m0l0 <- process_data(dat, a = 1, m = 0, l = 0)
  d_a1m0l1 <- process_data(dat, a = 1, m = 0, l = 1)
  d_a1m1l0 <- process_data(dat, a = 1, m = 1, l = 0)
  d_a1m1l1 <- process_data(dat, a = 1, m = 1, l = 1)

  # Copies with A and M fixed, L left at its observed value
  d_a0m0 <- process_data(dat, a = 0, m = 0, l = dat$L)
  d_a0m1 <- process_data(dat, a = 0, m = 1, l = dat$L)
  d_a1m0 <- process_data(dat, a = 1, m = 0, l = dat$L)
  d_a1m1 <- process_data(dat, a = 1, m = 1, l = dat$L)

  # Copies with only A fixed
  d_a0 <- process_data(dat, a = 0, m = dat$M, l = dat$L)
  d_a1 <- process_data(dat, a = 1, m = dat$M, l = dat$L)

  # Columns of dat addressed by each coefficient vector (matched by name)
  cols_l <- match(names(coef_l), colnames(dat))
  cols_m <- match(names(coef_m), colnames(dat))
  cols_f <- match(names(coef_f), colnames(dat))

  # Logistic probability of a 1 from data d, selected columns and coefficients
  prob_one <- function(d, cols, coefs) 1/(1 + exp(-d[, cols]%*%coefs))
  # Linear predictor of the f model on data d
  f_of <- function(d) d[, cols_f]%*%coef_f

  # p(L | M, A, C)
  p_l1a0m0 <- prob_one(d_a0m0, cols_l, coef_l)
  p_l0a0m0 <- 1 - p_l1a0m0
  p_l1a0m1 <- prob_one(d_a0m1, cols_l, coef_l)
  p_l0a0m1 <- 1 - p_l1a0m1
  # The A = 1 versions are evaluated but do not enter m below
  p_l1a1m0 <- prob_one(d_a1m0, cols_l, coef_l)
  p_l0a1m0 <- 1 - p_l1a1m0
  p_l1a1m1 <- prob_one(d_a1m1, cols_l, coef_l)
  p_l0a1m1 <- 1 - p_l1a1m1

  # p(M | A, C)
  p_m1a0 <- prob_one(d_a0, cols_m, coef_m)
  p_m0a0 <- 1 - p_m1a0
  p_m1a1 <- prob_one(d_a1, cols_m, coef_m)
  p_m0a1 <- 1 - p_m1a1

  # f evaluated at every (l, m, a) combination
  f_l0m0a0 <- f_of(d_a0m0l0)
  f_l0m0a1 <- f_of(d_a1m0l0)
  f_l0m1a0 <- f_of(d_a0m1l0)
  f_l0m1a1 <- f_of(d_a1m1l0)
  f_l1m0a0 <- f_of(d_a0m0l1)
  f_l1m0a1 <- f_of(d_a1m0l1)
  f_l1m1a0 <- f_of(d_a0m1l1)
  f_l1m1a1 <- f_of(d_a1m1l1)

  # \sum_{m, l} f(l, m, A, c) * p(L=l | A=0, M=m, c) * p(M=m | A=1, c)
  f_A1 <- f_l0m0a1*p_l0a0m0*p_m0a1 + f_l0m1a1*p_l0a0m1*p_m1a1 + f_l1m0a1*p_l1a0m0*p_m0a1 + f_l1m1a1*p_l1a0m1*p_m1a1
  f_A0 <- f_l0m0a0*p_l0a0m0*p_m0a1 + f_l0m1a0*p_l0a0m1*p_m1a1 + f_l1m0a0*p_l1a0m0*p_m0a1 + f_l1m1a0*p_l1a0m1*p_m1a1

  # Weighted centring terms, shared by all four outcomes of each arm
  ctr_A1 <- sum(px*f_A1)
  ctr_A0 <- sum(px*f_A0)

  # E[Y | A = 1, M = m, L = l, C]
  y_a1m0l0 <- f_l0m0a1 - ctr_A1 + w0 + wa
  y_a1m0l1 <- f_l1m0a1 - ctr_A1 + w0 + wa
  y_a1m1l0 <- f_l0m1a1 - ctr_A1 + w0 + wa
  y_a1m1l1 <- f_l1m1a1 - ctr_A1 + w0 + wa

  # E[Y | A = 0, M = m, L = l, C]
  y_a0m0l0 <- f_l0m0a0 - ctr_A0 + w0
  y_a0m0l1 <- f_l1m0a0 - ctr_A0 + w0
  y_a0m1l0 <- f_l0m1a0 - ctr_A0 + w0
  y_a0m1l1 <- f_l1m1a0 - ctr_A0 + w0

  # Contrast of the A = 1 and A = 0 outcomes, weighted by p(M | A, C) and
  # by p(L | A = 0, M, C)
  m <- ( (y_a1m0l0*p_m0a1 - y_a0m0l0*p_m0a0)*p_l0a0m0 +
         (y_a1m0l1*p_m0a1 - y_a0m0l1*p_m0a0 )*p_l1a0m0 +
         (y_a1m1l0*p_m1a1 - y_a0m1l0*p_m1a0)*p_l0a0m1 +
         (y_a1m1l1*p_m1a1 - y_a0m1l1*p_m1a0 )*p_l1a0m1 )

  return(m)
}


# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

# Alternating estimation of observation weights and model coefficients.
#
# Each iteration
#   1. evaluates m at the current coefficients (get_m_unified),
#   2. derives lambda and updated weights (get_lambda / get_pi, defined
#      elsewhere),
#   3. refits all coefficients under the updated weights
#      (optimize_unified_reparam).
# Iteration stops once the weights move by less than 0.001 and each
# coefficient block moves by less than 0.005 (Euclidean norm), or after
# opt$max_iter iterations.
#
# Arguments:
#   dat         data set passed to the helpers
#   fmla        list holding the formulas fmla_m, fmla_l and fmla_f
#   px          initial observation weights
#   beta_start  list(beta_m, beta_l, beta_y) of named start values; when it
#               has length 0 every coefficient starts at 0.1
#   opt         list; opt$max_iter is read here, and opt is forwarded to
#               optimize_unified_reparam. opt$threshold is not consulted:
#               the tolerances are fixed in the code.
#
# Returns list(beta_m, beta_y, px, Yhat, nde, beta_l).
optimize_unified <- function(dat, fmla, px, beta_start, opt){

  max_iter <- opt$max_iter
  # At least one pass is required, otherwise no estimate exists to return
  stopifnot(is.numeric(max_iter), length(max_iter) == 1, max_iter >= 1)

  # Default start values: 0.1 everywhere, with 2 extra slots for w0 and wa
  if (length(beta_start) == 0){
    mf <- model.frame(dat, na.action = NULL)
    Xm <- as.matrix(model.matrix(fmla$fmla_m, data = mf))
    Xl <- as.matrix(model.matrix(fmla$fmla_l, data = mf))
    Xf <- as.matrix(model.matrix(fmla$fmla_f, data = mf))

    flat <- rep(0.1, ncol(Xm) + ncol(Xl) + ncol(Xf) + 2)
    names(flat) <- c(colnames(Xm), colnames(Xl), colnames(Xf), "w0", "wa")
    beta_start <- list(beta_m = flat[1:ncol(Xm)],
                       beta_l = flat[(ncol(Xm)+1):(ncol(Xm)+ncol(Xl))],
                       beta_y = flat[(ncol(Xm)+ncol(Xl)+1):length(flat)])
  }

  for (i in seq_len(max_iter)){

    cat("iter: ", i, "\n")

    m <- get_m_unified(beta_start, dat, px)
    lambda <- get_lambda(m, dat)
    px_new <- get_pi(m, lambda, dat)

    beta_m_st <- beta_start$beta_m
    beta_l_st <- beta_start$beta_l
    beta_y_st <- beta_start$beta_y

    # The optimizer takes the start values as one flat vector
    beta_up <- optimize_unified_reparam(dat, fmla, px_new,
                                        c(beta_m_st, beta_l_st, beta_y_st), opt)
    beta_m_up <- beta_up$beta_m
    beta_l_up <- beta_up$beta_l
    # Compare the full (beta_f, w0, wa) vectors. Dropping wa on one side only
    # would subtract vectors of different length and recycle silently.
    beta_y_up <- beta_up$beta_y

    err_beta_m <- sqrt(sum((beta_m_st - beta_m_up)^2))
    err_beta_l <- sqrt(sum((beta_l_st - beta_l_up)^2))
    err_beta_y <- sqrt(sum((beta_y_st - beta_y_up)^2))
    err_pi <- sqrt(sum((px_new - px)^2))

    # Reported with the weights and m from the start of this iteration
    cat("nde = ", sum(px*m), "\n")

    if (err_pi < 0.001 && err_beta_m < 0.005 && err_beta_l < 0.005 && err_beta_y < 0.005){
      break
    }else{
      beta_start <- beta_up
      px <- px_new
    }
  }

  beta_m <- beta_up$beta_m
  beta_l <- beta_up$beta_l
  beta_y <- beta_up$beta_y

  # Final quantities use the most recent weights and coefficients
  px <- px_new
  m <- get_m_unified(beta_up, dat, px)
  nde <- sum(px*m)

  # Fitted outcome under the final estimates
  Yhat <- estimate_Y(dat, beta_y, beta_l, beta_m, px)

  # beta_l is appended last so existing positional access is unaffected
  return(list(beta_m=beta_m,
              beta_y=beta_y,
              px=px,
              Yhat=Yhat,
              nde = nde,
              beta_l=beta_l))
}





